import torch
import torch.nn as nn
import numpy as np

def create_attrs(topK, keepTopK):
    """Build the attribute dictionary passed to the batched NMS plugin node.

    Args:
        topK: Value stored under the ``"topK"`` attribute.
        keepTopK: Value stored under the ``"keepTopK"`` attribute.

    Returns:
        A new ``dict`` holding the fixed plugin settings together with the
        two caller-supplied values.
    """
    return {
        "shareLocation": 1,  # Default for yolor
        "backgroundLabelId": -1,
        "numClasses": 80,
        "topK": topK,
        "keepTopK": keepTopK,
        "scoreThreshold": 0.4,
        "iouThreshold": 0.6,
        "isNormalized": 0,  # Default yolor did not perform normalization
        "clipBoxes": 0,
        "plugin_version": "1",
    }

def create_and_add_plugin_node(graph, topK, keepTopK):
    """Append a ``BatchedNMSDynamic_TRT`` node and make its outputs the graph outputs.

    The graph must already contain tensors named ``"bboxes"`` and
    ``"scores"``; they become the two inputs of the plugin node. The graph's
    previous outputs are replaced by the four plugin outputs
    (``num_detections``, ``nmsed_boxes``, ``nmsed_scores``, ``nmsed_classes``).

    Args:
        graph: An onnx_graphsurgeon graph; it is modified in place.
        topK: Forwarded to ``create_attrs`` as the ``"topK"`` attribute.
        keepTopK: Forwarded to ``create_attrs`` as the ``"keepTopK"``
            attribute and used as the second dimension of the ``nmsed_*``
            output shapes.

    Returns:
        The result of ``graph.cleanup(remove_unused_node_outputs=True).toposort()``.
    """
    # Imported locally so the function also works when this file is imported
    # as a module: the only other import of onnx_graphsurgeon in this file
    # sits under the ``__main__`` guard.
    import onnx_graphsurgeon as gs

    tensors = graph.tensors()
    boxes_tensor = tensors["bboxes"]
    confs_tensor = tensors["scores"]

    # Plugin outputs; the leading -1 leaves the first dimension dynamic.
    num_detections = gs.Variable(name="num_detections").to_variable(dtype=np.int32, shape=[-1, 1])
    nmsed_boxes = gs.Variable(name="nmsed_boxes").to_variable(dtype=np.float32, shape=[-1, keepTopK, 4])
    nmsed_scores = gs.Variable(name="nmsed_scores").to_variable(dtype=np.float32, shape=[-1, keepTopK])
    nmsed_classes = gs.Variable(name="nmsed_classes").to_variable(dtype=np.float32, shape=[-1, keepTopK])

    new_outputs = [num_detections, nmsed_boxes, nmsed_scores, nmsed_classes]

    nms_node = gs.Node(
        op="BatchedNMSDynamic_TRT",
        attrs=create_attrs(topK, keepTopK),
        inputs=[boxes_tensor, confs_tensor],
        outputs=new_outputs)

    graph.nodes.append(nms_node)
    # Replacing the outputs lets cleanup() drop anything that no longer
    # contributes to the NMS results.
    graph.outputs = new_outputs

    return graph.cleanup(remove_unused_node_outputs=True).toposort()



class YOLORPostProcessor(nn.Module):
    """Turn raw detector output into corner-format boxes and per-class scores."""

    def __init__(self):
        super().__init__()

    def forward(self, output):
        """Split ``output`` into bounding boxes and class scores.

        Along the last axis, ``output`` is read as ``x, y, w, h`` (box centre
        and size) at indices 0-3, the objectness confidence at index 4, and
        the class confidences from index 5 onwards. The expected layout is
        ``[n_batch, n_bboxes, 85]``.

        Returns:
            A ``(bboxes, scores)`` tuple:

            * ``bboxes`` -- ``x1, y1, x2, y2`` corners with a singleton axis
              inserted at position 2, i.e. ``[n_batch, n_bboxes, 1, 4]``.
            * ``scores`` -- objectness multiplied by each class confidence.
        """
        # Centre/size -> corner conversion, done for both axes at once.
        centers = output[..., 0:2]
        half_sizes = output[..., 2:4] / 2
        corners = torch.cat([centers - half_sizes, centers + half_sizes], dim=-1)

        # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
        bboxes = corners.unsqueeze(2)

        # conf = obj_conf * cls_conf (objectness broadcasts over the classes)
        scores = output[..., 4:5] * output[..., 5:]

        print(bboxes.shape, scores.shape)
        return bboxes, scores

if __name__ == '__main__':
    # Script entry point:
    #   1. export YOLORPostProcessor to a small stand-alone ONNX model,
    #   2. merge it behind the detector model given by --model,
    #   3. append the batched NMS plugin node and save "<model>-nms<ext>".
    import argparse
    import os

    parser = argparse.ArgumentParser(description="Add batchedNMSPlugin")
    parser.add_argument("-f", "--model", help="Path to the ONNX model generated by convert_to_onnx.py", default="yolor_csp_x_star.onnx")
    # type=int: without it, values given on the command line arrive as strings
    # and end up in the plugin attributes and output shapes.
    parser.add_argument("-t", "--topK", type=int, help="number of bounding boxes for nms", default=1000)
    parser.add_argument("-k", "--keepTopK", type=int, help="bounding boxes to be kept per image", default=200)

    args, _ = parser.parse_known_args()

    # Intermediate file holding the exported post-processing graph.
    post_process_path = "yolo_post_process.onnx"

    model = YOLORPostProcessor()
    model.eval()
    sample_output = torch.rand(size=(1, 25200, 85))
    bboxes, scores = model(sample_output)
    print(bboxes.shape, scores.shape)

    # Export the model
    torch.onnx.export(model,                      # model being run
                      (sample_output,),           # model input (or a tuple for multiple inputs)
                      post_process_path,          # where to save the model (can be a file or file-like object)
                      export_params=True,         # store the trained parameter weights inside the model file
                      opset_version=11,           # the ONNX version to export the model to
                      do_constant_folding=True,   # whether to execute constant folding for optimization
                      input_names=['output'],     # the model's input names
                      output_names=['bboxes', 'scores'],  # the model's output names
                      dynamic_axes={'bboxes': [0, 1],     # dynamic batch & number of boxes
                                    'scores': [0, 1]})

    # Kept at module scope: create_and_add_plugin_node may rely on the
    # module-level name ``gs`` bound here.
    import onnx
    import onnx_graphsurgeon as gs

    model1 = onnx.load(args.model)
    model2 = onnx.load(post_process_path)

    # Feed the detector's 'output' tensor into the post-processor's 'output' input.
    combined_model = onnx.compose.merge_models(
        model1, model2,
        io_map=[('output', 'output')]
    )
    onnx.checker.check_model(combined_model)

    graph = gs.import_onnx(combined_model)
    # The helper already returns the cleaned-up, topologically sorted graph,
    # so no second cleanup()/toposort() pass is needed here.
    graph = create_and_add_plugin_node(graph, args.topK, args.keepTopK)

    # Insert "-nms" before the extension. Unlike str.replace(".onnx", ...),
    # this never rewrites other parts of the path and never yields the input
    # path itself (which would overwrite the source model).
    model_root, model_ext = os.path.splitext(args.model)
    onnx.save(gs.export_onnx(graph), model_root + "-nms" + model_ext)
